Tensor Puzzles - Penzai Edition¶
- by Sasha Rush - srush_nlp
This is a version of the tensor puzzles implemented in the JAX Penzai library.
Penzai is a really nice fit for these puzzles both because it comes with a really clean visualization library built-in and because it has a very nice named-tensor implementation.
I recommend running in Colab. Click here and copy the notebook to get started.
#!pip install -qqq jaxtyping hypothesis pytest penzai
import jax.numpy as np
import numpy as onp
from penzai import pz
# Short aliases for the Penzai named-array helpers used by every puzzle.
arange = pz.nx.arange
# `np.where` wrapped with `nmap` so it can be called on named arrays.
where = pz.nx.nmap(np.where)
wrap = pz.nx.wrap
# NOTE(review): presumably makes Penzai's treescope the default notebook
# renderer -- confirm against the Penzai docs.
pz.ts.register_as_default()
# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
# Array display settings: continuous colormap centred on zero, with axis "j"
# preferred as columns and axis "i" preferred as rows.
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer(force_continuous=True, around_zero=True, prefers_column=["j"], prefers_row=["i"]))
import inspect
import random
from jaxtyping import Int32
NamedArray = pz.nx.NamedArray
def make_test(name, problem, problem_spec, add_sizes=None,
              init_size=None,
              constraint=lambda d: d):
    """Check a puzzle solution against its reference spec on random inputs.

    Three random instances are generated from the jaxtyping annotations of
    `problem`. For each one the pure-Python `problem_spec` fills in the
    expected output and `problem` is run on the named-array inputs. Prints
    "Correct" if all three match and "Failure" otherwise.

    Args:
        name: Label for the puzzle (not used by the checker).
        problem: The solution function. Every parameter and the return value
            must be annotated with a jaxtyping type whose `dims` give the
            named axes.
        problem_spec: Reference implementation. It is called with plain
            nested lists, one per parameter of `problem` (in order), followed
            by the output buffer, which it fills in place.
        add_sizes: Accepted for compatibility; not used.
        init_size: Optional mapping from the first character of a dimension
            name to a fixed size. Other sizes are drawn from 2..7.
        constraint: Hook applied to the dict of generated arrays (including
            the "return" entry) before the spec is run.

    Returns:
        A list with one dict per instance holding the inputs, the "target"
        array and, when the solution returned something, "yours".
    """
    # `None` replaces the original mutable `{}` / `[]` defaults.
    init_size = {} if init_size is None else init_size
    num_trials = 3

    # Named axes of every argument and of the result, read off the annotations.
    signature = inspect.signature(problem)
    args = {
        arg: [d.name for d in param.annotation.dims]
        for arg, param in signature.parameters.items()
    }
    args["return"] = [d.name for d in signature.return_annotation.dims]

    def make_instance():
        """Build one random instance: named arrays plus plain-list copies."""
        example = {}
        reg = {}
        sizes = dict(init_size)
        for n in args:
            size = {}
            for dim in args[n]:
                # Dimensions sharing a first character ("i1", "i2") share a size.
                if dim[0] not in sizes:
                    sizes[dim[0]] = random.randint(2, 7)
                size[dim] = sizes[dim[0]]
            if "_s" in n:
                # Index arguments ("i_s", "j1_s", ...) are aranges over their axis.
                axis = list(size.keys())[0]
                example[n] = pz.nx.arange(axis, size[axis])
            else:
                v = onp.random.randint(-5, 5, list(size.values()))
                example[n] = pz.nx.wrap(v).tag(*args[n])
        example = constraint(example)
        for n in args:
            x = example[n].untag(*args[n])
            reg[n] = x.unwrap().tolist()
            if len(args[n]) == 0:
                # Scalars are boxed in a one-element list so the spec can
                # write the result in place.
                reg[n] = [0]
        return example, reg

    examples = []
    correct = 0
    for _ in range(num_trials):
        example, reg = make_instance()
        del example["return"]
        problem_spec(*reg.values())
        # Unbox scalar results. Keyed on the declared return dims rather than
        # on the buffer length, so a genuine length-1 output stays a vector.
        if not args["return"]:
            reg["return"] = reg["return"][0]
        yours = problem(**example)
        example["target"] = wrap(reg["return"]).tag(*args["return"])
        if yours is not None:
            example["yours"] = yours
            same = example["target"] == example["yours"]
            if same.untag(*same.named_shape.keys()).unwrap().all():
                correct += 1
        examples.append(example)
    if correct == num_trials:
        print("Correct")
    else:
        print("Failure")
    return examples
Rules¶
- Each puzzle needs to be solved in 1 line (<80 columns) of code.
- You are only allowed to use `contract`, `where`, and indexing.
- You are not allowed anything else. No `view`, `sum`, `take`, `squeeze`, `tensor`.
# Example of named infix ops.
# Each (a, b) pair mixes an "i"-axis vector with a "j"-axis vector, so `a + b`
# combines arrays whose named axes differ.
a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]
# Inside the comprehension `a` and `b` are the per-pair loop variables; the
# lists above are only consumed by `zip`.
[{"a": a, "b":b, "ret": a + b} for a, b in zip(a, b)]
# Example of where
# First case: condition and both branches all live on axis "i".
# Second case: the condition is an "i" x "j" matrix while the branches are
# vectors over "i" and "j" respectively.
examples = [(wrap([False, True], "i"), wrap([1, 1], "i"), wrap([-1, 0], "i")),
            (wrap([[False, True], [True, False]], "i", "j"), wrap([0, 1], "i"), wrap([-1, 0], "j")),
            ]
[{"q": q, "a":a, "b":b, "ret": where(q, a, b)} for q, a, b in examples]
# Example of contraction
def contract(n, *ts):
    """Multiply all of `ts` together, then sum the product over named axis `n`."""
    product = 1
    for factor in ts:
        product = product * factor
    # Untagging `n` turns it into a positional axis for `np.sum` to reduce.
    positional = product.untag(n)
    return pz.nx.nmap(np.sum)(positional)
# Same two vector pairs as the infix example; each product `a * b` is
# contracted over axis "i".
a = [arange("i", 4), arange("j", 5)]
b = [arange("j", 3), arange("i", 2)]
[{"a": a, "b":b, "ret": contract("i", a * b)} for a, b in zip(a, b)]
Puzzle 1 - ones¶
Compute ones - the vector of all ones.
def ones_spec(i_s, out):
    """Reference for `ones`: write 1 into `out` at every index listed in `i_s`."""
    for position in i_s:
        out[position] = 1
def ones(i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    # Zero out the index vector and add 1, so the result keeps the "i" axis.
    return i_s * 0 + 1
# Compare `ones` with `ones_spec` on random instances.
make_test("one", ones, ones_spec)
Correct
Puzzle 2 - sum¶
Compute sum - the sum of a vector.
def sum_spec(a, out):
    """Reference for `sum`: accumulate every element of `a` into `out[0]`."""
    for value in a:
        out[0] += value
def sum(a: Int32[NamedArray, "i"]) -> Int32[NamedArray, ""]:
    # With a single operand, `contract` just sums `a` over axis "i".
    # NOTE: this definition shadows the built-in `sum` from here on.
    return contract("i", a)
# Compare `sum` with `sum_spec` on random instances.
make_test("sum", sum, sum_spec)
Correct
Puzzle 3 - outer¶
Compute outer - the outer product of two vectors.
def outer_spec(a, b, out):
    """Reference for `outer`: fill `out[i][j]` with `a[i] * b[j]`."""
    for i, row in enumerate(out):
        # The width is taken from the first row, as for a rectangular matrix.
        for j in range(len(out[0])):
            row[j] = a[i] * b[j]
def outer(a: Int32[NamedArray, "i"], b : Int32[NamedArray, "j"]) -> Int32[NamedArray, "i j"]:
    # `a` lives on axis "i" and `b` on axis "j"; their product is annotated
    # (and checked by make_test) as an "i" x "j" matrix.
    return a * b
# Compare `outer` with `outer_spec` on random instances.
make_test("outer", outer, outer_spec)
Correct
Puzzle 4 - diag¶
Compute diag - the diagonal vector of a square matrix.
def diag_spec(a, i1_s, out):
    """Reference for `diag`: copy the main diagonal of `a` into `out`.

    `i1_s` is accepted to mirror the solution's signature and is not read.
    """
    for k, row in enumerate(a):
        out[k] = row[k]
def diag(a: Int32[NamedArray, "i1 i2"], i1_s: Int32[NamedArray, "i1"]) -> Int32[NamedArray, "i1"]:
    # Index both axes with the same index vector, picking a[k, k] for each k.
    return a[{"i1": i1_s, "i2": i1_s}]
# Compare `diag` with `diag_spec` on random instances.
make_test("diag", diag, diag_spec)
Correct
Puzzle 5 - eye¶
Compute eye - the identity matrix.
def eye_spec(i1_s, i2_s, out):
    """Reference for `eye`: 1 where the row index equals the column index, else 0."""
    for row in i1_s:
        for col in i2_s:
            out[row][col] = 1 if row == col else 0
def eye(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    # Compare every "i1" index with every "i2" index; equal pairs become 1.
    return where(i1_s == i2_s, 1, 0)
# Compare `eye` with `eye_spec` on random instances.
make_test("eye", eye, eye_spec)
Correct
Puzzle 6 - triu¶
Compute triu - the upper triangular matrix.
def triu_spec(i1_s, i2_s, out):
    """Reference for `triu`: 1 on and above the main diagonal, 0 below it."""
    for row in i1_s:
        for col in i2_s:
            out[row][col] = 1 if row <= col else 0
def triu(i1_s: Int32[NamedArray, "i1"], i2_s : Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i1 i2"]:
    # 1 wherever the row index is <= the column index.
    return where(i1_s <= i2_s, 1, 0)
# Compare `triu` with `triu_spec` on random instances.
make_test("triu", triu, triu_spec)
Correct
Puzzle 7 - cumsum¶
Compute cumsum - the cumulative sum.
def cumsum_spec(a, i1_s, i2_s, out):
    """Reference for `cumsum`: `out[i]` is the sum of `a[0]` through `a[i]`.

    `i1_s` and `i2_s` mirror the solution's signature and are not read.
    """
    running = 0
    for i in range(len(out)):
        running += a[i]
        out[i] = running
def cumsum(a: Int32[NamedArray, "i1"], i1_s : Int32[NamedArray, "i1"], i2_s: Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    # The mask is 1 where i1 <= i2, so contracting over "i1" sums a[0..i2]
    # for every output position i2.
    return contract("i1", where(i1_s <= i2_s, 1, 0), a)
# Compare `cumsum` with `cumsum_spec` on random instances.
make_test("cumsum", cumsum, cumsum_spec)
Correct
Puzzle 8 - diff¶
Compute diff - the running difference. In this edition the difference is circular: the first entry is a[0] minus the last element of a.
def diff_spec(a, i_s, out):
    """Reference for `diff`: `out[i] = a[i] - a[i - 1]`, wrapping at the start.

    For `i == 0` the index `i - 1` is -1, so the first entry is
    `a[0] - a[-1]` (a circular difference).

    The original body began with `out[0] = a[0]`, which the loop immediately
    overwrote and which raised IndexError on empty input; that dead store has
    been removed.

    NOTE(review): the removed line suggests a non-circular diff
    (`out[0] = a[0]`) may once have been intended. The circular form is kept
    because it is what the notebook's solution is checked against.

    `i_s` mirrors the solution's signature and is not read.
    """
    for i in range(len(out)):
        out[i] = a[i] - a[i - 1]
def diff(a: Int32[NamedArray, "i"], i1_s : Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    # Subtract the previous element from each element. At position 0 the
    # index is -1, which diff_spec treats as the last element of `a`.
    return a - a[{"i": i1_s - 1}]
# Compare `diff` with `diff_spec` on random instances.
make_test("diff", diff, diff_spec)
Correct
Puzzle 9 - stack¶
Compute stack (vstack) - the two-row matrix whose rows are the two vectors.
def stack_spec(a, b, out):
    """Reference for `stack`: row 0 of `out` becomes `a`, row 1 becomes `b`."""
    first_row = out[0]
    for k in range(len(first_row)):
        first_row[k] = a[k]
        out[1][k] = b[k]
def stack(a: Int32[NamedArray, "i"], b: Int32[NamedArray, "i"]) -> Int32[NamedArray, "j i"]:
    # Build a length-2 "j" axis: row j == 1 takes `b`, row j == 0 takes `a`.
    return where(arange("j", 2) == 1, b, a)
# `init_size` pins the "j" axis to 2 rows to match the two stacked vectors.
make_test("stack", stack, stack_spec, init_size={"j" : 2})
Correct
Puzzle 10 - roll¶
Compute roll - the vector shifted 1 circular position.
def roll_spec(a, i_s, out):
    """Reference for `roll`: `out[i]` is the next element of `a`, wrapping at the end.

    `i_s` mirrors the solution's signature and is not read.
    """
    n = len(out)
    for i in range(n):
        # (i + 1) % n is i + 1 everywhere except the last slot, which wraps to 0.
        out[i] = a[(i + 1) % n]
def roll(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    # Read each element from the following position, taking the index modulo
    # the length of axis "i" so the last slot wraps around to position 0.
    return a[{"i": (i_s + 1) % i_s.named_shape["i"]}]
# Compare `roll` with `roll_spec` on random instances.
make_test("roll", roll, roll_spec)
Correct
Puzzle 11 - flip¶
Compute flip - the reversed vector
def flip_spec(a, i_s, out):
    """Reference for `flip`: `out` receives `a` in reverse order.

    `i_s` mirrors the solution's signature and is not read.
    """
    n = len(out)
    for i in range(n):
        out[i] = a[n - 1 - i]
def flip(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"]) -> Int32[NamedArray, "i"]:
    # Indices -1, -2, ..., -n: negative indexing from the end reverses `a`.
    return a[{"i": (-i_s - 1)}]
# NOTE: `add_sizes` is accepted by make_test but never read there.
make_test("flip", flip, flip_spec, add_sizes=["i"])
Correct
Puzzle 12 - compress¶
Compute compress - keep only the entries of v whose mask value in g is greater than 1, left-aligned and padded with 0s.
def compress_spec(g, v, i1_s, i2_s, out):
    """Reference for `compress`: left-align the entries of `v` whose `g` exceeds 1.

    `out` is zeroed first, then each selected `v[i]` is written to the next
    free slot. `i1_s` and `i2_s` mirror the solution's signature and are not
    read.
    """
    out[:] = [0] * len(out)
    next_slot = 0
    for i, gate in enumerate(g):
        if gate > 1:
            out[next_slot] = v[i]
            next_slot += 1
def compress(g: Int32[NamedArray, "i1"], v: Int32[NamedArray, "i2"], i1_s:Int32[NamedArray, "i1"], i2_s:Int32[NamedArray, "i2"]) -> Int32[NamedArray, "i2"]:
    # Replaces the placeholder `return g`. Uses only `contract`, `where` and
    # indexing, but is spread over several lines instead of the single line
    # the puzzle rules ask for.
    # 1 where the entry is kept (same `g > 1` test as compress_spec), else 0.
    keep = where(g > 1, 1, 0)
    # Inclusive running count of kept entries (a cumsum of `keep`), computed
    # on axis "i2" and then re-indexed so it is laid out along "i1".
    count = contract("i1", where(i1_s <= i2_s, 1, 0), keep)[{"i2": i1_s}]
    # One-hot "i1" x "i2" routing matrix: input i goes to slot count[i] - 1.
    slot = where(count - 1 == i2_s, 1, 0)
    # Re-index `v` onto "i1", drop the unkept entries via `keep`, and scatter
    # the rest to their slots; output slots nothing maps to stay 0.
    return contract("i1", v[{"i2": i1_s}], keep, slot)
# Compare `compress` with `compress_spec` on random instances.
make_test("compress", compress, compress_spec)
Failure
Puzzle 13 - pad_to¶
Compute pad_to - eliminate or add 0s to change size of vector.
def pad_to_spec(a, i_s, j_s, out):
    """Reference for `pad_to`: copy `a` into `out`, truncating or zero-padding.

    `i_s` and `j_s` mirror the solution's signature and are not read.
    """
    source_len = len(a)
    for k in range(len(out)):
        out[k] = a[k] if k < source_len else 0
def pad_to(a: Int32[NamedArray, "i"], i_s:Int32[NamedArray, "i"], j_s:Int32[NamedArray, "j"]) -> Int32[NamedArray, "j"]:
    # The mask is 1 only where j == i, so contracting over "i" copies a[i] to
    # out[j == i]. Output slots with no matching i sum to 0, and inputs with
    # no matching j are dropped.
    return contract("i", a, where(j_s == i_s, 1, 0))
# Compare `pad_to` with `pad_to_spec` on random instances.
make_test("pad_to", pad_to, pad_to_spec)
Correct
Puzzle 14 - sequence_mask¶
Compute sequence_mask - pad out to length per batch.
# Didn't do
# def sequence_mask_spec(values, length, out):
# for i in range(len(out)):
# for j in range(len(out[0])):
# if j < length[i]:
# out[i][j] = values[i][j]
# else:
# out[i][j] = 0
# def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
# pass
# def constraint_set_length(d):
# d["length"] = d["length"] % d["values"].shape[1]
# return d
# make_test("sequence_mask",
# sequence_mask, sequence_mask_spec, constraint=constraint_set_length
# )
Puzzle 15 - bincount¶
Compute bincount - count number of times an entry was seen.
def bincount_spec(a, i_s, j1_s, j2_s, out):
    """Reference for `bincount`: `out[k]` counts how often `k` occurs in `a`.

    `i_s`, `j1_s` and `j2_s` mirror the solution's signature and are not read.
    """
    out[:] = [0] * len(out)
    for value in a:
        out[value] += 1
def bincount(a: Int32[NamedArray, "i"], i_s: Int32[NamedArray, "i"],
             j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    # Selecting identity-matrix rows by the values in `a` gives one one-hot
    # "j2" vector per element; summing those over "i" counts each value.
    return contract("i", eye(j1_s, j2_s)[{"j1": a}])
def constraint_set_max(d):
    """Reduce `d["a"]` modulo the size of the output's "j2" axis, in place.

    This keeps every value in `a` inside the range of valid bin indices.
    Returns the same dict `d`.
    """
    num_bins = d["return"].named_shape["j2"]
    d["a"] = d["a"] % num_bins
    return d
# The constraint folds the random values of `a` into the valid bin range
# before the spec and the solution are compared.
make_test("bincount",
          bincount, bincount_spec, constraint=constraint_set_max
          )
Correct
Puzzle 16 - scatter_add¶
Compute scatter_add - add together values that link to the same location.
def scatter_add_spec(values, link, j1_s, j2_s, out):
    """Reference for `scatter_add`: add each `values[j]` into `out[link[j]]`.

    `out` is zeroed first. `j1_s` and `j2_s` mirror the solution's signature
    and are not read.
    """
    out[:] = [0] * len(out)
    for j, value in enumerate(values):
        out[link[j]] += value
def scatter_add(values: Int32[NamedArray, "i"], link: Int32[NamedArray,"i"],
                j1_s: Int32[NamedArray, "j1"], j2_s: Int32[NamedArray, "j2"]) -> Int32[NamedArray, "j2"]:
    # Identity rows selected by `link` form a one-hot routing matrix over
    # ("i", "j2"); weighting it by `values` and summing over "i" adds every
    # value into its linked output slot.
    return contract("i", values, eye(j1_s, j2_s)[{"j1": link}])
def constraint_set_max(d):
    """Reduce `d["link"]` modulo the size of the output's "j2" axis, in place.

    This keeps every link inside the range of valid output positions.
    Returns the same dict `d`. This redefines the earlier helper of the same
    name used for bincount.
    """
    num_slots = d["return"].named_shape["j2"]
    d["link"] = d["link"] % num_slots
    return d
# The constraint folds the random `link` values into the valid output range
# before the spec and the solution are compared.
make_test("scatter_add",
          scatter_add, scatter_add_spec, constraint=constraint_set_max
          )
Correct
import inspect
# Report the length of each solution's code line so it can be compared with
# the "1 line, <80 columns" rule.
fns = (ones, sum, outer, diag, eye, triu, cumsum, diff, stack, roll, flip,
       compress, pad_to, bincount, scatter_add)
for fn in fns:
    # Comment-only lines are ignored, so `#` comments inside a solution do
    # not affect the count.
    lines = [l for l in inspect.getsource(fn).split("\n") if not l.strip().startswith("#")]
    # A one-line solution yields up to three entries: the `def` line, the
    # `return` line and (presumably) a trailing empty string from the split.
    # Anything longer is flagged, and index 2 is reported on the assumption
    # that the signature spans two lines.
    if len(lines) > 3:
        print(fn.__name__, len(lines[2]), "(more than 1 line)")
    else:
        print(fn.__name__, len(lines[1]))
ones 22 sum 27 outer 16 diag 38 eye 36 triu 36 cumsum 55 diff 33 stack 43 roll 53 flip 31 compress 12 pad_to 52 bincount 52 (more than 1 line) scatter_add 63 (more than 1 line)